package main

import (
	"flag"
	"fmt"
	"github.com/bwmarrin/snowflake"
	"github.com/gorilla/websocket"
	"log"
	"net/http"
	"strings"
)

var (
	// addr is the host:port the HTTP server listens on; override with -addr.
	addr = flag.String("addr", "localhost:8111", "http service address")

	// upgrader turns incoming HTTP requests into WebSocket connections using
	// 1 KiB read and write buffers.
	upgrader = websocket.Upgrader{
		ReadBufferSize:  1 << 10,
		WriteBufferSize: 1 << 10,
		// Accept every Origin so that cross-origin pages are allowed to connect.
		CheckOrigin: func(*http.Request) bool { return true },
	}

	// connMap holds, for each "uid" form value, the connections registered
	// under it.
	connMap = map[string][]*websocket.Conn{}

	// connMapByFD holds each connection under its generated fd string.
	connMapByFD = map[string]*websocket.Conn{}
)

func connect(w http.ResponseWriter, r *http.Request) {
	// 完成和Client HTTP >>> WebSocket的协议升级

	c, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Print("upgrade:", err)
		return
	}
	err = r.ParseForm()
	if err == nil {
		uid := r.FormValue("uid")
		log.Print("ParseForm-uid:", uid)
		connMap[uid] = append(connMap[uid], c)
	}

	node, err := snowflake.NewNode(1)
	if err != nil {
		log.Print("snowflake:", err)
		return
	}
	fd := node.Generate().String()
	connMapByFD[fd] = c

	err = c.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("{\"fd\":%s}", fd)))
	if err != nil {
		log.Println("write:", err)
	}
	defer func(c *websocket.Conn) {
		err := c.Close()
		//当连接断开，删除fd
		delete(connMap, fd)
		if err != nil {
			log.Println("connClose:", err)
		}
	}(c)
	for {
		// 接收客户端message
		mt, message, err := c.ReadMessage()
		if err != nil {
			log.Println("read:", err)
			break
		}
		if string(message) == "ping" {
			err := c.WriteMessage(websocket.TextMessage, []byte("pong"))
			if err != nil {
				return
			}
		}

		log.Printf("recv: %s", message)
		log.Printf("recv: %d", mt)
	}
}

// pushAllByFD sends the "text" form value as a text message to every
// connection in connMapByFD and answers the HTTP caller with "success".
//
// The sends run in a background goroutine, so the HTTP response is written
// without waiting for them; per-connection write errors are only logged.
// An empty "text" sends nothing but still answers "success". A request whose
// form cannot be parsed is rejected with 400 Bad Request.
//
// NOTE(review): the goroutine ranges over connMapByFD without a lock while
// other goroutines may modify it, and may write to a connection at the same
// time as another sender — both need shared synchronization.
func pushAllByFD(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		// A malformed request must not take the whole server down, which is
		// what log.Fatal would do here.
		log.Println("parse form error:", err)
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	text := r.FormValue("text")
	if text != "" {
		payload := []byte(text)
		go func() {
			for _, conn := range connMapByFD {
				if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
					log.Println("connfd-err:", err.Error())
				}
			}
		}()
	}

	log.Println("write:", text)
	if _, err := w.Write([]byte("success")); err != nil {
		log.Println("write response:", err)
	}
}

// pushByFD pushes a message to several connections selected by fd.
//
// It reads two form values: "text", the message, and "fds", a comma-separated
// list of fds. Each fd found in connMapByFD receives the text as a text
// message; unknown fds are skipped. The HTTP caller is answered with
// "success".
//
// The sends run in a background goroutine, so the HTTP response is written
// without waiting for them; per-connection write errors are only logged.
// An empty "text" sends nothing but still answers "success". A request whose
// form cannot be parsed is rejected with 400 Bad Request.
//
// NOTE(review): connMapByFD is read here without a lock while other goroutines
// may modify it, and a connection may be written to by another sender at the
// same time — both need shared synchronization.
func pushByFD(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		// A malformed request must not take the whole server down, which is
		// what log.Fatal would do here.
		log.Println("parse form error:", err)
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	text := r.FormValue("text")
	fdsArr := strings.Split(r.FormValue("fds"), ",")
	if text != "" {
		payload := []byte(text)
		go func() {
			for _, fd := range fdsArr {
				log.Println("push-fd:", fd)
				connfd, ok := connMapByFD[fd]
				if !ok {
					continue
				}
				if err := connfd.WriteMessage(websocket.TextMessage, payload); err != nil {
					log.Println("connfd-err:", err.Error())
				}
			}
		}()
	}

	log.Println("write:", text)
	if _, err := w.Write([]byte("success")); err != nil {
		log.Println("write response:", err)
	}
}
// pushAll sends the "text" form value as a text message to every connection
// stored in connMap (all connections of all uids) and answers the HTTP caller
// with "success".
//
// The sends run in a background goroutine, so the HTTP response is written
// without waiting for them; per-connection write errors are only logged.
// An empty "text" sends nothing but still answers "success". A request whose
// form cannot be parsed is rejected with 400 Bad Request.
//
// NOTE(review): the goroutine ranges over connMap without a lock while other
// goroutines may modify it, and may write to a connection at the same time as
// another sender — both need shared synchronization.
func pushAll(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		// A malformed request must not take the whole server down, which is
		// what log.Fatal would do here.
		log.Println("parse form error:", err)
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	text := r.FormValue("text")
	if text != "" {
		payload := []byte(text)
		go func() {
			for _, conns := range connMap {
				for _, conn := range conns {
					if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
						log.Println("connfd-err:", err.Error())
					}
				}
			}
		}()
	}

	log.Println("write:", text)
	if _, err := w.Write([]byte("success")); err != nil {
		log.Println("write response:", err)
	}
}

// push pushes a message to several users selected by id.
//
// It reads two form values: "text", the message, and "ids", a comma-separated
// list of ids. For each id present in connMap, every connection registered
// under it receives the text as a text message; unknown ids are skipped. The
// HTTP caller is answered with "success".
//
// The sends run in a background goroutine, so the HTTP response is written
// without waiting for them; per-connection write errors are only logged.
// An empty "text" sends nothing but still answers "success". A request whose
// form cannot be parsed is rejected with 400 Bad Request.
//
// NOTE(review): connMap is read here without a lock while other goroutines may
// modify it, and a connection may be written to by another sender at the same
// time — both need shared synchronization.
func push(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		// A malformed request must not take the whole server down, which is
		// what log.Fatal would do here.
		log.Println("parse form error:", err)
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}

	text := r.FormValue("text")
	idsArr := strings.Split(r.FormValue("ids"), ",")
	if text != "" {
		payload := []byte(text)
		go func() {
			for _, id := range idsArr {
				log.Println("push-fd:", id)
				// A missing id yields a nil slice, so the inner loop is skipped.
				for _, conn := range connMap[id] {
					if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
						log.Println("connfd-err:", err.Error())
					}
				}
			}
		}()
	}

	log.Println("write:", text)
	if _, err := w.Write([]byte("success")); err != nil {
		log.Println("write response:", err)
	}
}
// main parses the command-line flags, registers the WebSocket and push
// endpoints on the default mux and serves HTTP on *addr until the server
// fails, at which point the error is logged and the process exits.
func main() {
	flag.Parse()
	// Emit log lines without the date/time prefix.
	log.SetFlags(0)

	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/connect", connect},
		{"/pushAll", pushAll},
		{"/push", push},
		{"/pushAllByFD", pushAllByFD},
		{"/pushByFD", pushByFD},
	}
	for _, route := range routes {
		http.HandleFunc(route.pattern, route.handler)
	}

	serveErr := http.ListenAndServe(*addr, nil)
	log.Fatal(serveErr)
}
